""" Utilities for the run scripts

"""
import itertools
import copy
import pandas as pd
import numpy as np
import json
try:
    from src.interpolation.splines import UCGrid
except:
    from ..interpolation.splines import UCGrid
import statsmodels.formula.api as smf

def read_in_data(
    path_moments_data='./models/smm_input/moments_empirical.csv',
    path_weights='./models/smm_input/weights.csv',
    path_n_rigs="./models/first_stage/n_rigs",
    path_surplus_components="./models/surplus/surplus_components",
    path_surplus_grid='./models/surplus/surplus_grid_2_low_month.npy',
    path_df_state='./data_py/processed/states',
    path_delta='./models/smm_input/delta.csv',
    path_entry_cost='./models/smm_input/entry_cost.csv',
    path_rho='./models/smm_input/rho.csv',
    path_df_contracts='./data_py/processed/contracts_final.csv',
    path_price_match_values='./models/price_match/price_match_values',
    path_coefs_data='./models/smm_input/coefs_data',
    path_prob_match_predict_contracts='./models/robustness/prob_match_predict_contracts',
    path_prob_match_predict='./models/robustness/prob_match_predict',
    time_period='month',
    p_exit=0.0,
    use_myopic=False,
    bootstrap_seed=''):
    """ Read every input file used by the run scripts and bundle the results.

    Args:
        path_moments_data: csv file; its column '0' is returned as the
            empirical moments
        path_weights: header-less csv file; the first column is the index and
            the remaining column is converted to a dict
        path_n_rigs: file prefix; reads '{prefix}_{time_period}.json'
        path_surplus_components: file prefix; reads
            '{prefix}_{tau}_{spec}_{time_period}.txt' for tau in (2, 3, 4)
            and spec in ('low', 'mid', 'high')
        path_surplus_grid: .npy file; its rows 0-4 are passed (as tuples) as
            the five positional arguments of UCGrid
        path_df_state: file prefix; reads '{prefix}_{time_period}.csv', which
            must contain the columns date, g, n_l, n_m, n_h and gas
        path_delta: csv file; column '0' is returned as delta
        path_entry_cost: csv file; column '0' is returned as c
        path_rho: csv file; column '0' is returned as rho
        path_df_contracts: csv file with (at least) the columns reneg, spec
            and value
        path_price_match_values: file prefix; reads
            '{prefix}_{spec}_{k}_{time_period}{bootstrap_seed}.json' for
            k in (0, 1, 2, 3)
        path_coefs_data: file prefix; reads
            '{prefix}_{time_period}{bootstrap_seed}.csv'
        path_prob_match_predict_contracts: file prefix; reads
            '{prefix}_{spec}.csv' (only if use_myopic is False)
        path_prob_match_predict: file prefix; reads '{prefix}_{spec}.csv'
            (only if use_myopic is False)
        time_period: string inserted into the file names above
        p_exit: stored as non_myopic_dict['prob_exit']
        use_myopic: if True, no match-probability files are read and
            data['non_myopic_dict'] is None
        bootstrap_seed: suffix appended to the price-match and coefficient
            file names

    Returns:
        data: dict collecting the loaded inputs
        delta: column '0' of the delta file
        rho: column '0' of the rho file
        c: column '0' of the entry cost file
        weights: dict built from the weights file

    """
    specs = ['low', 'mid', 'high']

    # %% READ IN THE INPUTS -------------------------------------------------------------
    moments_data = pd.read_csv(path_moments_data, index_col=[0])
    moments_data = moments_data['0']
    with open(f"{path_n_rigs}_{time_period}.json") as f:
        n_rigs = json.load(f)

    match_values_by_tau_spec = dict()
    for tau, spec in itertools.product([2, 3, 4], specs):
        match_values_by_tau_spec[(tau, spec)] = np.loadtxt(
            f'{path_surplus_components}_{tau}_{spec}_{time_period}.txt'
        )

    # Using that the grid for surplus is the same for each (tau, spec)
    surplus_grid_params = np.load(path_surplus_grid)
    surplus_grid = UCGrid(
        tuple(surplus_grid_params[0, :]),
        tuple(surplus_grid_params[1, :]),
        tuple(surplus_grid_params[2, :]),
        tuple(surplus_grid_params[3, :]),
        tuple(surplus_grid_params[4, :])
    )

    # Get state data
    df_state = pd.read_csv(f'{path_df_state}_{time_period}.csv', index_col=[0])
    df_state['date'] = pd.to_datetime(df_state['date'])
    state_data = df_state[['date', 'g', 'n_l', 'n_m', 'n_h']]
    # Take an explicit copy before adding a column: assigning into the result
    # of a column selection triggers pandas' SettingWithCopyWarning.
    g_data = state_data[['g', 'date']].copy()
    g_data['2006'] = (g_data['date'].dt.year == 2006)
    g_cutoff = g_data['g'].mean()

    # %% SETUP FIXED INPUTS -------------------------------------------------------------
    delta = pd.read_csv(path_delta, index_col=[0])['0']
    c = pd.read_csv(path_entry_cost, index_col=[0])['0']
    rho = pd.read_csv(path_rho, index_col=[0])['0']

    # %% GET THE PRICE-MATCH MOMENTS ----------------------------------------------------
    df_contracts = pd.read_csv(path_df_contracts, index_col=[0])
    mean_reneg_data = df_contracts['reneg'].mean()

    values_by_spec = dict()
    for spec in specs:
        values_by_spec[spec] = df_contracts.loc[df_contracts['spec'] == spec, 'value']

    price_match_values_by_spec = dict()
    for spec in specs:
        price_match_values_by_spec[spec] = dict()
        for k in [0, 1, 2, 3]:
            with open(
                    f"{path_price_match_values}_{spec}_{k}_{time_period}{bootstrap_seed}.json") as f:
                price_match_values_by_spec[spec][k] = np.array(json.load(f))

    coefs_data = pd.read_csv(
        f'{path_coefs_data}_{time_period}{bootstrap_seed}.csv', index_col=[0]
    ).squeeze()

    # %% ADD IN NON-MYOPIC DICT ---------------------------------------------------------
    if not use_myopic:
        non_myopic_dict = dict()
        non_myopic_dict['prob_exit'] = p_exit
        non_myopic_dict['prob_match_contracts'] = dict()
        non_myopic_dict['prob_match'] = dict()
        for spec in specs:
            # .T[0] keeps the first data column of each file as a 1-d array
            non_myopic_dict['prob_match_contracts'][spec] = np.array(pd.read_csv(
                f'{path_prob_match_predict_contracts}_{spec}.csv',
                index_col=[0]
            )).T[0]
            non_myopic_dict['prob_match'][spec] = np.array(pd.read_csv(
                f'{path_prob_match_predict}_{spec}.csv',
                index_col=[0]
            )).T[0]
    else:
        non_myopic_dict = None

    # %% ADD IN WEIGHTS -----------------------------------------------------------------
    weights = pd.read_csv(
        path_weights, index_col=0, header=None
    ).squeeze("columns").to_dict()
    data = {
        'state_data': state_data,
        'moments_data': moments_data,
        'match_values_by_tau_spec': match_values_by_tau_spec,
        'surplus_grid': surplus_grid,
        'n_rigs': n_rigs,
        'g_data': g_data,
        'g_cutoff': g_cutoff,
        'price_match_values_by_spec': price_match_values_by_spec,
        'values_by_spec': values_by_spec,
        'df_contracts': df_contracts,
        'coefs_data': coefs_data,
        'mean_reneg_data': mean_reneg_data,
        'non_myopic_dict': non_myopic_dict,
        'gas_price': df_state['gas']
    }

    return data, delta, rho, c, weights


def get_aggregated_moments_from_data(df_contracts, df_state,
                                     contract_length_spec='high'):
    """ Get the aggregate moments. Function so it can be used for
    bootstrapping.

    Args:
        df_contracts: contracts data with the columns spec, reneg, tau, boom
            and mri
        df_state: state data with the columns g, 2006 (boolean mask) and
            utilization_{spec} for spec in ('low', 'mid', 'high')
        contract_length_spec: spec whose new contracts are used for the
            contract length moments p_2 and p_3. Pass None to pool the new
            contracts of all specs. The default 'high' reproduces the
            earlier behaviour, where the filter used the loop variable left
            over from the utilization loop.
            NOTE(review): confirm whether restricting to 'high' is intended.

    Returns:
        moments_agg: moments to be used computed from the data

    Raises:
        ZeroDivisionError: if no new contract with tau in (2, 3, 4) matches
            the contract length filter

    """
    specs = ['low', 'mid', 'high']

    # Get utilization moments
    moments_agg = dict()
    for spec in specs:
        utilization = df_state[f'utilization_{spec}']
        moments_agg[f'utilization_covariance_{spec}'] = \
            np.cov(utilization, df_state['g'])[0, 1]
        moments_agg[f'utilization_variance_{spec}'] = np.var(utilization)
        moments_agg[f'utilization_mean_{spec}'] = np.mean(utilization)
        moments_agg[f'util_2006_{spec}'] = \
            np.mean(df_state.loc[df_state['2006'], f'utilization_{spec}'])

    # Get contract length moments
    # Only look at new contracts (i.e. reneg == 0)
    is_counted = (df_contracts['reneg'] == 0)
    if contract_length_spec is not None:
        is_counted = is_counted & (df_contracts['spec'] == contract_length_spec)

    contract_total = dict()
    contract_total_denom = 0.0
    for tau in [2, 3, 4]:
        # int() keeps plain Python division, so an empty selection raises
        # ZeroDivisionError instead of silently producing nan
        contract_total[f'p_{tau}'] = int(
            (is_counted & (df_contracts['tau'] == tau)).sum()
        )
        contract_total_denom += contract_total[f'p_{tau}']
    moments_agg['p_2'] = contract_total['p_2'] / contract_total_denom
    moments_agg['p_3'] = contract_total['p_3'] / contract_total_denom

    # Get complexity moments
    for spec in specs:
        is_spec = (df_contracts['spec'] == spec)
        # Compare against True/False explicitly so that a 0/1 coded boom
        # column behaves the same as a boolean one
        moments_agg[f'mri_mean_{spec}_boom'] = df_contracts.loc[(
            (df_contracts['boom'] == True)
            & is_spec
        ), 'mri'].mean()
        moments_agg[f'mri_mean_{spec}_bust'] = df_contracts.loc[(
            (df_contracts['boom'] == False)
            & is_spec
        ), 'mri'].mean()

    moments_agg['mri_variance'] = np.var(df_contracts['mri'])

    return moments_agg


def get_price_coefficients(df_contracts, price_match_values_by_spec, delta):
    """ Regress the day rate net of the outside option and collect the
    coefficients together with the average day rate per spec.

    Note that df_contracts is modified in place: the columns 'outside' and
    'diff_dayrate' are added (or overwritten).

    Args:
        df_contracts: contracts data with the columns spec, day_rate, mri, g
            and value
        price_match_values_by_spec: mapping from spec ('low', 'mid', 'high')
            to the values that, scaled by (1 - delta), are written to the
            'outside' column of the rows of that spec
            (presumably a scalar or one value per contract of that spec;
            verify against caller)
        delta: scalar used in the (1 - delta) scaling

    Returns:
        coefs_data: fitted OLS parameters, extended with the entries
            price_low, price_mid and price_high (mean day rate by spec)

    """
    spec_levels = ('low', 'mid', 'high')

    # %% SETUP DATA ---------------------------------------------------------------------
    df_contracts['outside'] = 0.0
    for level in spec_levels:
        rows = (df_contracts['spec'] == level)
        df_contracts.loc[rows, 'outside'] = \
            (1 - delta) * price_match_values_by_spec[level]

    # %% GET COEFFICIENTS ---------------------------------------------------------------
    df_contracts['diff_dayrate'] = (
        df_contracts['day_rate'] - df_contracts['outside']
    )
    formula = (
        'diff_dayrate '
        '~ C(spec, Treatment(reference="low")) '
        '+ mri : C(spec, Treatment(reference="low")) '
        '+ g : value'
    )
    fitted = smf.ols(formula=formula, data=df_contracts).fit()
    coefs_data = fitted.params.T

    # %% GET AVERAGE PRICES -------------------------------------------------------------
    for level in spec_levels:
        rows = (df_contracts['spec'] == level)
        coefs_data[f'price_{level}'] = df_contracts.loc[rows, 'day_rate'].mean()

    return coefs_data


def get_entry_cost(data_g, params, shares_by_state):
    """ Build the table of shares and derive the number of entrants and the
    entry cost for each row.

    Args:
        data_g: values of g, one per row of shares_by_state (or a scalar)
        params: mapping with the entries 'd_0', 'd_1' and 'c'
        shares_by_state: array-like with (at least) three columns; columns
            0, 1 and 2 are subtracted from one to get the 'no_enter' share

    Returns:
        df_shares: the shares as a dataframe with the added columns
            'no_enter' and 'number_enter'
        entry_cost: 'number_enter' multiplied by params['c']
        n_enter: independent copy of the 'number_enter' column

    """
    df_shares = pd.DataFrame(np.array(shares_by_state))

    # Share that is left after removing the first three columns
    df_shares['no_enter'] = 1 - df_shares[0] - df_shares[1] - df_shares[2]

    scale = params['d_0'] + data_g * params['d_1']
    df_shares['number_enter'] = scale * (1 - df_shares['no_enter'])

    # Deep copies so the returned series do not share data with df_shares
    entry_cost = (df_shares['number_enter'] * params['c']).copy(deep=True)
    n_enter = df_shares['number_enter'].copy(deep=True)

    return df_shares, entry_cost, n_enter


def get_counterfactual_decomposition(df_counterfactuals, gas_price, counter,
                                     comparison_by_counter):
    """ Split the difference between a counterfactual and its comparison
    into steps.

    Every returned series is a total value minus 'opex' minus an entry cost;
    the steps differ in which prefix supplies the total value and which one
    the entry cost.

    Args:
        df_counterfactuals: dataframe with the column 'opex' and, for both
            prefixes, the columns '{prefix}_total_value',
            '{prefix}_entry_cost' and '{prefix}_av_utilization'
        gas_price: passed through unchanged into the result
        counter: column prefix of the counterfactual
        comparison_by_counter: column prefix of the comparison (used
            directly as a prefix)

    Returns:
        series: dict with the keys 'final', 'entry', 'quality', 'initial'
            and 'gas_price'

    """
    base = comparison_by_counter
    opex = df_counterfactuals['opex']

    def net_value(gross_value, entry_cost_prefix):
        # Total value net of operating and entry costs
        entry_cost = df_counterfactuals[f'{entry_cost_prefix}_entry_cost']
        return gross_value - opex - entry_cost

    base_value = df_counterfactuals[f'{base}_total_value']
    counter_value = df_counterfactuals[f'{counter}_total_value']

    # Counterfactual value rescaled by the ratio of comparison to
    # counterfactual average utilization
    utilization_ratio = (
        df_counterfactuals[f'{base}_av_utilization']
        / df_counterfactuals[f'{counter}_av_utilization']
    )
    rescaled_counter_value = counter_value * utilization_ratio

    series = dict()
    series['final'] = net_value(counter_value, counter)
    series['entry'] = net_value(base_value, counter)
    series['quality'] = net_value(rescaled_counter_value, counter)
    series['initial'] = net_value(base_value, base)
    series['gas_price'] = gas_price

    return series
